#!/usr/bin/python
"""A script to apply a trained model to some data.

"""
import sys
import getopt
import numpy as np

import rsvd

# Module metadata: copied verbatim from the corresponding attributes of the
# imported `rsvd` package rather than being defined separately here.
__version__ = rsvd.__version__
__author__  = rsvd.__author__
__license__ = rsvd.__license__

class Usage(Exception):
    """Exception signalling that the command line was used incorrectly.

    The explanation handed to the constructor is stored on the instance
    as the ``msg`` attribute, which is what `main` prints before showing
    the help text.
    """

    def __init__(self, msg):
        # Keep the explanation around for whoever catches the exception.
        self.msg = msg

def usage():
    """Write the command-line help text to standard output.

    The text is emitted through ``sys.stdout.write`` instead of the
    Python-2-only ``print`` statement, so the function is valid syntax on
    both Python 2 and Python 3. The bytes written are the same as before:
    the help text followed by one extra newline.
    """
    # One entry per output line. Doubled backslashes are intentional: the
    # help text shows the escape sequences literally (e.g. "\t"), whereas
    # the single "\t" in the option list is a real tab used for alignment.
    # Trailing whitespace present in the original text is kept explicit.
    lines = (
        "Usage: rsvd_predict [options] model_dir",
        "    ",
        "Applies the trained model in <model_dir> to test data.",
        "Two formats are supported:",
        "  text input (<movie_id>\\t<user_id>\\n), or",
        "  numpy record array (dtype `rsvd.rating_t`).",
        "",
        "Output is always textual "
        "(<movie_id>\\t<user_id>\\t<predicted_rating>\\n).",
        "",
        "See <http://code.google.com/p/pyrsvd/> for further information. ",
        "",
        "Options:",
        "  -h, --help\tPrint this help",
        "  -o <file>\tThe output file. If '-' or not set stdout is used.",
        "  -i <file>\tThe input file. If '-' or not set stdin is used.",
        "",
        "For bug reporting, please mail to:",
        "<peter.prettenhofer@gmail.com>",
        "",
    )
    # The final empty entry plus the explicit "\n" below reproduce the
    # newline that ended the original literal and the one `print` appended.
    sys.stdout.write("\n".join(lines) + "\n")

def arrayIterFD(fd):
    """Iterate over text lines as tuples of integers.

    Gives an iterable of text lines (typically an open file object such
    as ``sys.stdin``) a record-array-like interface: every line is split
    on whitespace and each field is converted with ``int``, so fields can
    be accessed by position (``record[0]``, ``record[1]``, ...).

    Blank or whitespace-only lines are skipped instead of being yielded
    as empty tuples, which callers indexing into the record would choke
    on (e.g. a trailing empty line at the end of the input).

    Parameters:
      fd -- any iterable yielding lines of text.

    Yields:
      One tuple of ints per non-blank line.

    Raises:
      ValueError -- lazily, during iteration, if a field is not a valid
      integer literal.
    """
    for line in fd:
        # split() without arguments already ignores leading/trailing
        # whitespace, including the line terminator.
        fields = line.split()
        if not fields:
            continue
        yield tuple(int(field) for field in fields)

def openInputFile(file):
    """Open the input file and return an iterable of rating records.

    If the file name ends with '.txt' the file is opened in text mode and
    wrapped via `arrayIterFD` to create a mock record array interface;
    otherwise it is assumed to be a serialized numpy record array of
    dtype `rsvd.rating_t` and is read with `np.fromfile`.

    Parameters:
      file -- path of the input file.

    Returns:
      For '.txt' files, an iterator of int tuples; the underlying file is
      closed once iteration finishes or the iterator is closed. For any
      other file, the array returned by `np.fromfile`.

    Raises:
      Whatever `open` or `np.fromfile` raise for an unreadable file. The
      text file is opened eagerly, so such errors surface from this call
      and not from the first iteration step.
    """
    if file.endswith(".txt"):
        return _closingRecordIter(open(file, 'r'))
    return np.fromfile(file, dtype=rsvd.rating_t)


def _closingRecordIter(fd):
    """Yield the records parsed from `fd` and close `fd` afterwards."""
    try:
        for record in arrayIterFD(fd):
            yield record
    finally:
        # Runs on exhaustion, on an error while parsing, and when the
        # generator is closed early, so the handle is not leaked.
        fd.close()

def main(argv=None):
    """Command-line entry point: apply a trained model to test data.

    Parses the options, loads the model stored in the given model
    directory and writes one line ``<id0>\\t<id1>\\t<prediction>\\n`` to
    the output for every input record, where the prediction is obtained
    by calling the model with the record's first two fields.

    Parameters:
      argv -- full argument vector including the program name at index 0;
              defaults to ``sys.argv``.

    Returns:
      0 on success, 1 if loading the model, reading the input or writing
      the output failed (the error is reported on stderr), 2 on a
      command-line usage error (the message and the help text are
      printed).

    Raises:
      SystemExit -- after printing the help text for -h / --help.
    """
    if argv is None:
        argv = sys.argv

    # Phase 1: command-line handling. Problems here are usage errors.
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "ho:i:", ["help"])
        except getopt.error as msg:
            raise Usage(msg)

        output_file = "-"
        input_file = "-"

        for o, a in opts:
            if o in ("-h", "--help"):
                usage()
                sys.exit()
            elif o == "-o":
                # Plain equality: `o in ("-o")` was a substring test on
                # the string "-o", not a tuple membership test.
                output_file = a
            elif o == "-i":
                input_file = a
            else:
                raise Usage("unhandled option")

        if len(args) < 1:
            raise Usage("missing arguments")
        model_dir = args[0]
    except Usage as err:
        sys.stderr.write("%s\n" % (err.msg,))
        usage()
        return 2

    # Phase 2: the actual work. Failures are reported on stderr and
    # turned into a non-zero exit status so callers can detect them.
    output = None
    try:
        try:
            model = rsvd.RSVD.load(model_dir)

            # Compare by value; identity checks against a string literal
            # are unreliable. "-" selects the standard streams.
            if input_file == "-":
                input_iterator = arrayIterFD(sys.stdin)
            else:
                input_iterator = openInputFile(input_file)

            # The output file is opened last, and for writing, so an
            # existing file is only truncated once the model and the
            # input have been opened successfully.
            if output_file == "-":
                output = sys.stdout
            else:
                output = open(output_file, 'w')

            for example in input_iterator:
                output.write("%d\t%d\t%f\n" % (
                    example[0], example[1],
                    model(example[0], example[1])))
        finally:
            # Close only what this function opened; never sys.stdout.
            if output is not None and output is not sys.stdout:
                output.close()
    except Exception as e:
        sys.stderr.write("Error:  %s\n" % (e,))
        return 1

    return 0

# Run `main` only when executed as a script, and hand its return value to
# the interpreter as the process exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
